I’m trying to rewrite XTTS in JAX to understand how it works.
We are going to implement the HiFiGAN used in (Casanova et al. 2024), a text-to-speech model written by the now-defunct Coqui company. The HiFiGAN model was first proposed in (Kong, Kim, and Bae 2020). Its role is to take mel spectrograms that represent a sequence of speech and transform them into a wav that we can listen to. It sits at the very end of XTTS: the rest of the model generates the spectrograms that feed it. HiFiGAN is composed of a generator and two discriminators.
Generator
This part of the model takes in the mel spectrogram and iteratively convolves the input into the desired output shape, here a \([1 \times N]\) tensor representing the time series (a list of points) giving the amplitude of the sound. Passing this, together with a sampling rate, to an audio program lets us listen to it!
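To make "passing this with a sampling rate to an audio program" concrete, here is a small stdlib-only sketch that writes a waveform to disk the same way we would save the generator's \([1 \times N]\) output. The filename and the 22050 Hz rate are made-up values for illustration; we use a 440 Hz sine as a stand-in for the model output.

```python
import math
import struct
import wave

# Hypothetical example: one second of a 440 Hz sine, saved as 16-bit mono PCM.
sr = 22050  # sampling rate (assumed value for this sketch)
samples = [math.sin(2 * math.pi * 440 * t / sr) for t in range(sr)]

with wave.open("tone.wav", "wb") as f:
    f.setnchannels(1)   # mono, matching the [1 x N] output
    f.setsampwidth(2)   # 16-bit samples
    f.setframerate(sr)
    f.writeframes(b"".join(struct.pack("<h", int(s * 32767)) for s in samples))
```

Any audio player can then open `tone.wav`; swapping `samples` for the generator's output (clipped to \([-1, 1]\), which `tanh` guarantees) gives us listenable speech.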
Discriminators
Since we want the output to be as real as possible, we can use other models that learn to discern the output of our generator from real inputs. Basically, if such a model can’t see the difference between real and fake audio, then either our generator is doing a good job or the discriminator is doing a bad one!
Here the paper proposes two: one that looks at evenly spaced-out points to find patterns that could help discern real from fake, and one that looks at the signal at different scales. We can think of this second one as an inspector examining a doctored image for artifacts in the hair, or, at a bigger scale, checking how the eyes are placed relative to the face. Anyways.
What we hope happens
2 Goal 🎯
Get intelligible speech coming out of the HiFiGAN with a reasonable amount of training (1 h tops) on an NVIDIA L40.
Since our final goal is to recreate a one-to-one version of the VQVAE used in XTTS, we’ll hardcode a lot of things to minimize issues.
We can now move onto implementing the rest of the Generator, which basically just uses a bunch of different ResBlocks with varying kernel sizes and dilations to capture as much information as possible, which is then slowly transformed into a waveform (i.e. an array of points with a sampling rate).
```python
class MRF(eqx.Module):
    resblocks: list

    def __init__(self, channel_in: int, kernel_sizes: list, dilations: list, key=None):
        if key is None:
            raise ValueError("The 'key' parameter cannot be None.")
        self.resblocks = [
            ResBlock(channel_in, kernel_size, dilation, key=y)
            for kernel_size, dilation, y in zip(
                kernel_sizes, dilations, jax.random.split(key, len(kernel_sizes))
            )
        ]

    def __call__(self, x):
        y = self.resblocks[0](x)
        for block in self.resblocks[1:]:
            y += block(x)
        return y / len(self.resblocks)
```
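The important detail in `__call__` is that the parallel branches are averaged, not just summed. A toy stand-in (plain Python, with each "ResBlock" replaced by a simple scaling so the arithmetic is visible) shows the pattern:

```python
# Toy stand-in for the MRF averaging pattern: three "branches" that each scale
# their input, whose outputs are summed and divided by the branch count.
branches = [lambda x, s=s: s * x for s in (1.0, 2.0, 3.0)]

def mrf_like(x):
    y = branches[0](x)
    for block in branches[1:]:
        y += block(x)
    return y / len(branches)

print(mrf_like(1.0))  # (1 + 2 + 3) / 3 = 2.0
```

Averaging keeps the output magnitude independent of how many kernel sizes we stack in the MRF.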
Finally, we can create each “MRF”, i.e. a list of ResBlocks, and call the whole thing the Generator.
```python
class Generator(eqx.Module):
    conv_pre: nn.Conv1d
    layers: list
    post_magic: nn.Conv1d
    norm = nn.WeightNorm

    def __init__(
        self,
        channels_in: int,
        channels_out: int,
        h_u=512,
        k_u=[16, 16, 4, 4],
        upsample_rate_decoder=[8, 8, 2, 2],
        k_r=[3, 7, 11],
        dilations=[1, 3, 5],
        key=None,
    ):
        if key is None:
            raise ValueError("The 'key' parameter cannot be None.")
        key, grab = jax.random.split(key, 2)
        self.conv_pre = nn.Conv1d(
            channels_in, h_u, kernel_size=7, dilation=1, padding=3, key=grab
        )
        # This is where the magic happens. Upsample aggressively then more slowly.
        # TODO could play around with this.
        # Then convolve one last time (curious to see the weights, to see if it has a good impact).
        self.layers = [
            (
                nn.ConvTranspose1d(
                    int(h_u / (2**i)),
                    int(h_u / (2 ** (i + 1))),
                    kernel_size=k,
                    stride=u,
                    padding="SAME",
                    key=y,
                ),
                MRF(
                    channel_in=int(h_u / (2 ** (i + 1))),
                    kernel_sizes=k_r,
                    dilations=dilations,
                    key=y,
                ),
            )
            for i, (k, u, y) in enumerate(
                zip(k_u, upsample_rate_decoder, jax.random.split(key, len(k_u)))
            )
        ]
        self.post_magic = nn.Conv1d(
            int(h_u / (2 ** len(k_u))),
            channels_out,
            kernel_size=7,
            stride=1,
            padding=3,
            use_bias=False,
            key=key,
        )

    def __call__(self, x):
        y = self.norm(self.conv_pre)(x)
        for upsample, mrf in self.layers:
            y = jax.nn.leaky_relu(y, LRELU_SLOPE)
            y = self.norm(upsample)(y)  # Upsample
            y = mrf(y)
        y = jax.nn.leaky_relu(y, LRELU_SLOPE)
        y = self.norm(self.post_magic)(y)
        y = jax.nn.tanh(y)
        return y
```
😮‍💨 Ok that was FAT! We now should have a model that can take in images like mel spectrograms and transform them into waves. Let’s quickly check it’s at least outputting the right dimensions given an input. Based on our upsample_rate_decoder, any mel spectrogram of the form \([Melbins \times length]\) becomes \([1 \times 256 \cdot length]\), since the upsample rates multiply out to \(8 \cdot 8 \cdot 2 \cdot 2 = 256\).
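We can verify that factor by hand (assuming the "SAME"-style padding above, so each transposed convolution scales the time axis by exactly its stride):

```python
# Each ConvTranspose1d multiplies the time axis by its stride, so the total
# upsampling factor is the product of the strides in upsample_rate_decoder.
upsample_rate_decoder = [8, 8, 2, 2]

factor = 1
for u in upsample_rate_decoder:
    factor *= u

print(factor)  # 256: e.g. an 80-frame mel spectrogram becomes 80 * 256 = 20480 samples
```

At a 22050 Hz sampling rate that 20480-sample example would be just under a second of audio.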
We can now move onto writing up the discriminators that will attempt to discern fake from real outputs.
XTTS uses the two discriminators described in the original paper (Kong, Kim, and Bae 2020).
Let’s start by writing the Periodic Discriminator. Both discriminators are actually sets of models with varying input sizes. The code below is mainly a rewrite of the code available here: (jik876 2020)
Notice that we’re also returning the activations of each intermediate layer. This is because we’ll implement a loss that doesn’t just compare the output but the intermediate outputs between real and fake inputs. We can now define the discriminator that looks at different scales.
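That intermediate-activation comparison is usually implemented as an L1 feature-matching loss. A toy sketch over plain Python lists (the real version runs over JAX arrays, but the arithmetic is the same):

```python
# Feature-matching loss sketch: for each discriminator layer, take the mean
# absolute difference between activations on real and generated audio, then
# sum across layers.
def feature_loss(fmaps_real, fmaps_fake):
    total = 0.0
    for real, fake in zip(fmaps_real, fmaps_fake):
        total += sum(abs(r - f) for r, f in zip(real, fake)) / len(real)
    return total

real = [[1.0, 2.0], [3.0]]   # made-up activations from two layers, real audio
fake = [[1.5, 2.0], [2.0]]   # made-up activations on generated audio
print(feature_loss(real, fake))  # 0.25 + 1.0 = 1.25
```

This gives the generator a much denser training signal than the discriminator's single real/fake score.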
Finally, we can define wrappers for both that will contain the various periods or scales. Notice that the periods are prime numbers, to avoid overlap as much as possible.
```python
class MultiScaleDiscriminator(eqx.Module):
    discriminators: list
    meanpool: nn.AvgPool1d = nn.AvgPool1d(4, 2, padding=2)
    # TODO need to add spectral norm things

    def __init__(self, key: jax.Array = jax.random.PRNGKey(0)):
        key1, key2, key3 = jax.random.split(key, 3)
        self.discriminators = [
            DiscriminatorS(key1),
            DiscriminatorS(key2),
            DiscriminatorS(key3),
        ]
        # self.meanpool = nn.AvgPool1d(4, 2, padding=2)

    def __call__(self, x):
        preds = []
        fmaps = []
        for disc in self.discriminators:
            pred, fmap = disc(x)
            preds.append(pred)
            fmaps.append(fmap)
            x = self.meanpool(x)  # Subtle way of scaling things down by 2
        return preds, fmaps


class MultiPeriodDiscriminator(eqx.Module):
    discriminators: list

    def __init__(
        self, periods=[2, 3, 5, 7, 11], key: jax.Array = jax.random.PRNGKey(0)
    ):
        self.discriminators = [
            DiscriminatorP(period, key=y)
            for period, y in zip(periods, jax.random.split(key, len(periods)))
        ]

    def __call__(self, x):
        preds = []
        fmaps = []
        for disc in self.discriminators:
            pred, fmap = disc(x)
            preds.append(pred)
            fmaps.append(fmap)
        return preds, fmaps
```
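To see why prime periods matter, here is a toy sketch of the reshaping a period discriminator performs (the actual `DiscriminatorP` pads the array and then applies 2-D convolutions; this plain-Python version only shows the layout):

```python
# A 1-D signal of length T is right-padded to a multiple of `period` and
# reshaped into rows of `period` samples, so a 2-D conv sees every
# period-th sample down each column.
def to_period_view(xs, period):
    pad = (-len(xs)) % period           # right-pad so the length divides evenly
    xs = list(xs) + [0.0] * pad
    return [xs[i:i + period] for i in range(0, len(xs), period)]

rows = to_period_view(range(10), 3)
print(rows)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0.0, 0.0]]
```

With prime periods (2, 3, 5, 7, 11), the columns of the different discriminators sample the signal at strides that rarely coincide, so each one inspects a genuinely different slice of the periodic structure.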
Congratulations 🎉 We’ve now implemented a HiFiGAN! Now “all” 🤏 that remains is to train it.
4 Training
4.1 Losses and gradients
We’re going to have multiple losses to deal with - a total of 3 actually. The first is an MSE comparing what our generator produces to the ideal output. The next two come from the discriminators. Finally, we also need to define losses for the discriminators themselves, which in our case are simply the MSE between the real predictions and 1, and between the fake predictions and 0.
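Those real-vs-1 and fake-vs-0 targets are the least-squares GAN formulation; a minimal plain-Python sketch of the two adversarial pieces:

```python
# Least-squares GAN losses: the discriminator pushes predictions on real audio
# toward 1 and on fakes toward 0, while the generator pushes the predictions
# on its fakes toward 1.
def mse(preds, target):
    return sum((p - target) ** 2 for p in preds) / len(preds)

def discriminator_loss(real_preds, fake_preds):
    return mse(real_preds, 1.0) + mse(fake_preds, 0.0)

def generator_adv_loss(fake_preds):
    return mse(fake_preds, 1.0)

print(discriminator_loss([1.0, 0.9], [0.1, 0.0]))  # small: discriminator is winning
print(generator_adv_loss([0.1, 0.0]))              # large: the fakes are being caught
```

In the real training loop these run over the prediction lists returned by both discriminator wrappers, and the JAX version just swaps the list arithmetic for `jnp.mean`.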
Now that we have a folder full of mel spectrograms ready to be fed into our model, we can start the training! Below, we do multiple things:
Initialize the model, set up the optimizer that will nudge it based on the losses our loss function returns, log with TensorBoard, and save the model every epoch.
I haven’t gotten around to training the model from A to Z, but we can see that the initial iterations are quite chaotic.
Results after 1 iteration
5 References
Casanova, Edresson, Kelly Davis, Eren Gölge, Görkem Göknar, Iulian Gulea, Logan Hart, Aya Aljafari, et al. 2024. “XTTS: A Massively Multilingual Zero-Shot Text-to-Speech Model.” https://arxiv.org/abs/2406.04904.
Kong, Jungil, Jaehyeon Kim, and Jaekyoung Bae. 2020. “HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis.” https://arxiv.org/abs/2010.05646.